import numpy as np
import matplotlib.pyplot as plt
from scipy.special import wofz
from matplotlib.colors import LogNorm
import matplotlib.cm as cm

# --- Plotting Setup ---
# Global Matplotlib styling: LaTeX-rendered text in a Computer Modern serif face.
# usetex=True needs a working LaTeX installation; switch it off if LaTeX is
# unavailable (note the \textbf{...} title/label strings assume usetex).
plt.rc("text", usetex=True)
plt.rc("font", family="serif", serif=["Computer Modern Roman"], size=14)
plt.rc("axes", titlesize=16, labelsize=14)

# --- Core Physics Functions ---

def f_norm(Omega):
    """
    f(Ω) = exp(-πΩ/2) / sqrt(8π |Ω| sinh(π|Ω|))

    Evaluated element-wise in log space,
        log f = -πΩ/2 - ½ [log(8π|Ω|) + log sinh(π|Ω|)],
    so the result stays finite where the direct formula overflows float64:
    sinh(π|Ω|) overflows for |Ω| ≳ 226 and exp(-πΩ/2) for Ω ≲ -452.

    Parameters
    ----------
    Omega : array_like
        Real frequency value(s) Ω.

    Returns
    -------
    numpy.ndarray or numpy.float64
        f evaluated at each Ω, with the same shape as the input.

    Notes
    -----
    f diverges like 1/|Ω| at Ω = 0. Exact zeros are replaced by |Ω| = 1e-16,
    giving a large but finite value (~1.1e15) instead of inf, so that a product
    Ω·f(Ω) evaluates to 0 there rather than NaN.
    """
    Om = np.asarray(Omega, dtype=np.float64)
    absOm = np.abs(Om)
    absOm = np.where(absOm == 0.0, 1e-16, absOm)

    x = np.pi * absOm
    # log(sinh x) = x - log 2 + log(1 - exp(-2x)).  Writing 1 - exp(-2x) as
    # -expm1(-2x) keeps it accurate for tiny x (where it is ≈ 2x), and the
    # whole expression never overflows for large x.
    log_sinh = x - np.log(2.0) + np.log(-np.expm1(-2.0 * x))
    log_f = -0.5 * np.pi * Om - 0.5 * (np.log(8.0 * np.pi * absOm) + log_sinh)
    return np.exp(log_f)

def M_unified(Omega, Omega_p, chi, chi_p, a, T, omega, *, g=1.0, hbar=1.0):
    """
    M_{χχ'} ∝ χχ' ΩΩ' a^{i(Ω+Ω')} f(-χΩ) f(-χ'Ω')
             exp[-(aT)^2/4 (χΩ+χ'Ω')^2] w( T[(a/2)(χΩ-χ'Ω') - ω] )

    f is ``f_norm`` and w is the Faddeeva function ``scipy.special.wofz``.
    The overall prefactor is g² a² / (2 ħ²).

    Parameters
    ----------
    Omega, Omega_p : array_like
        Ω and Ω' values; broadcast against each other element-wise.
    chi, chi_p : float or array_like
        Channel labels χ and χ' (the plotting script's RR channel passes
        χ = χ' = +1).
    a : float or array_like
        Acceleration. Must be strictly positive, because a^{i(Ω+Ω')} is
        evaluated as exp(i (Ω+Ω') log a).
    T : float
        Scale entering both the Gaussian width and the Faddeeva argument.
    omega : float
        Frequency ω subtracted inside the Faddeeva argument.
    g, hbar : float, keyword-only, default 1.0
        Coupling constant and ħ appearing in the prefactor.

    Returns
    -------
    complex128 ndarray (or scalar)
        The amplitude, with the broadcast shape of the inputs.

    Raises
    ------
    ValueError
        If any value of ``a`` is not strictly positive (including NaN).
    """
    a_arr = np.asarray(a, dtype=np.float64)
    if not np.all(a_arr > 0.0):
        # np.log(a) would otherwise silently yield NaN/-inf and blank the result.
        raise ValueError(f"acceleration a must be > 0, got {a!r}")

    Omega   = np.asarray(Omega,   dtype=np.float64)
    Omega_p = np.asarray(Omega_p, dtype=np.float64)

    # Signed combinations that recur in the Gaussian and Faddeeva factors.
    sum_freq  = chi * Omega + chi_p * Omega_p
    diff_freq = chi * Omega - chi_p * Omega_p

    prefactor = (g**2 * a**2) / (2.0 * hbar**2) * chi * chi_p * Omega * Omega_p
    f_product = f_norm(-chi * Omega) * f_norm(-chi_p * Omega_p)
    phase_factor = np.exp(1j * (Omega + Omega_p) * np.log(a))  # a^{i(Ω+Ω')}
    gaussian_term = np.exp(-0.25 * (a**2) * (T**2) * sum_freq**2)

    faddeeva_arg = T * (0.5 * a * diff_freq - omega)
    faddeeva_term = wofz(np.asarray(faddeeva_arg, dtype=np.complex128))

    return prefactor * phase_factor * f_product * gaussian_term * faddeeva_term

# --- Main Plotting Script ---

omega = 1.0
T = 5.0
accelerations = [1.0, 0.5, 0.1, 0.01]
axis_ranges   = [3,    6,   15,  120]  # half-width of each panel's Ω/Ω' window

n_points = 500  # grid points per axis

# Colormap shared by every panel (it is identical for all of them, so build it
# once). ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed
# in 3.9; ``pyplot.get_cmap`` works across versions. ``.copy()`` makes sure the
# under/bad colours are set on our own instance, not a registered colormap.
my_cmap = plt.get_cmap('jet').copy()  # or 'viridis'
my_cmap.set_under('darkblue')  # below vmin
my_cmap.set_bad('darkblue')    # masked

fig, axs = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle(r'\textbf{RR Emission Spectrum at Different Scales of Acceleration}', fontsize=18, y=0.96)

subplot_labels = ['(a)', '(b)', '(c)', '(d)']

for i, (a_val, plot_range) in enumerate(zip(accelerations, axis_ranges)):
    ax = axs.flat[i]

    Omega_vals = np.linspace(-plot_range, plot_range, n_points)
    Op_vals    = np.linspace(-plot_range, plot_range, n_points)
    O_grid, Op_grid = np.meshgrid(Omega_vals, Op_vals, indexing='xy')

    # RR channel (χ=+1, χ'=+1)
    Z_RR = M_unified(O_grid, Op_grid, 1, 1, a=a_val, T=T, omega=omega)
    Prob_RR = np.abs(Z_RR)**2

    # LogNorm cannot place non-positive or non-finite values: mask them so
    # contourf leaves those cells unfilled (the axes background shows through).
    Prob_masked = np.ma.masked_invalid(np.ma.masked_less_equal(Prob_RR, 0.0))

    # MaskedArray.max() skips masked cells and returns np.ma.masked when every
    # cell is masked; fall back to an arbitrary scale in that case.
    vmax = Prob_masked.max()
    if vmax is np.ma.masked or not np.isfinite(vmax):
        vmax = 1.0
    vmin = vmax * 1e-6  # display six decades below the peak

    log_levels = np.logspace(np.log10(vmin), np.log10(vmax), num=100)

    # Set the axes background to the under/bad colour so any unfilled region
    # blends in instead of showing white.
    ax.set_facecolor('darkblue')

    im = ax.contourf(
        O_grid, Op_grid, Prob_masked,
        levels=log_levels,
        cmap=my_cmap,
        norm=LogNorm(vmin=vmin, vmax=vmax),
        extend='min'
    )

    cbar = fig.colorbar(im, ax=ax, shrink=0.85)
    if i == 3:
        cbar.set_label(r'\textbf{Probability Density} $|\mathcal{M}|^2$')

    peak_loc = omega / a_val
    ax.set_title(fr'$\mathbf{{{subplot_labels[i]}\ a = {a_val}}}$'
                 fr' \quad (Peak at $\Omega \approx \pm {peak_loc:.0f}$)')

    # Label only the outer edges of the 2x2 grid.
    if i in [0, 2]:
        ax.set_ylabel(r"$\mathbf{\Omega'}$")
    if i in [2, 3]:
        ax.set_xlabel(r'$\mathbf{\Omega}$')

    ax.axhline(0, color='white', linestyle='--', linewidth=0.5, alpha=0.5)
    ax.axvline(0, color='white', linestyle='--', linewidth=0.5, alpha=0.5)
    ax.set_aspect('equal')

plt.tight_layout(rect=[0, 0, 1, 0.95])

output_filename = 'multiscale_blue_background.png'
plt.savefig(output_filename, dpi=300, bbox_inches='tight')
print(f"✅ Final plot saved as '{output_filename}'")
